#!/usr/bin/env python
from datetime import datetime
import getopt
import os
import pwd
import re
import sys

# runtime switches and paths, filled in during start-up
is_debug = False  # set by init_config_with_env(); log() only writes when this is True
is_brief = False  # set by -b/--brief; selects the short output format
has_print_title = False  # lets print_title() emit the header only once
log_file = ""  # debug log path, set by init_log_file_name()
djob_path = ''  # path of the djob executable, set by check_donau_cli_env()
# job filters taken from the command-line options
job_ids = []
job_names = []
job_states = []
job_users = []
job_partitions = []
# error lines collected by check_result_error(); printed after the job table
job_err_info = []
job_start_time = ''
job_end_time = ''
# time format accepted by -S/-E (validated with datetime.strptime)
SUPPORT_DATA_FORMAT = '%Y/%m/%d %H:%M:%S'
# keys of the per-job dict built from djob output
DISPLAY_ID = 'ID'
DISPLAY_NAME = 'NAME'
DISPLAY_USER = 'USER'
DISPLAY_STATE = 'STATE'
DISPLAY_QUEUE = 'QUEUE'
DISPLAY_ALLOC_CPUS = 'ALLOC_CPU'
DISPLAY_EXIT_MESSAGE = 'EXIT_MESSAGE'
DISPLAY_ERR_CODE = 'ERR_CODE'
# column titles of the sacct-style table printed by print_title()
DISPLAY_JOB_ID = 'JobID'
DISPLAY_JOB_NAME = 'JobName'
DISPLAY_JOB_STATE = 'State'
DISPLAY_JOB_EXITCODE = 'ExitCode'
DISPLAY_JOB_PARTITION = 'Partition'
DISPLAY_JOB_ACCOUNT = 'Account'
DISPLAY_JOB_ALLOC_CPUS = 'AllocCPUS'
# row layouts: the short one has 3 columns (brief mode), the long one has 7
DISPLAY_SHORT_FORMAT = '{:<12} {:>10} {:>8}'
DISPLAY_LONG_FORMAT = '{:<12} {:>10} {:>10} {:>10} {:>10} {:>10} {:>8}'
# field positions in a whitespace-split line of djob output
JOB_QUERY_ID_INDEX = 0
JOB_QUERY_NAME_INDEX = 1
JOB_QUERY_USER_INDEX = 2
JOB_QUERY_STATE_INDEX = 3
JOB_QUERY_QUEUE_INDEX = 4
# minimum number of fields a djob line needs before parse_job_info() accepts it
# NOTE(review): "IEN" is presumably a typo of "LEN" - confirm before renaming
JOB_QUERY_IEN = 8
# field positions in a whitespace-split line of `djob -x -CO` output
JOB_CUSTOM_ID_INDEX = 0
JOB_CUSTOM_ALLOC_CPUS_INDEX = 1
JOB_CUSTOM_EXIT_CODE = 2
JOB_CUSTOM_EXIT_MESSAGE = 3
# NOTE(review): presumably the number of columns requested with -CO (four) - confirm
JOB_CUSTOM_VALID_LEN = 4
# maximum number of characters shown in the JobName and Partition columns
JOB_NAME_LEN_LIMIT = 10
JOB_PARTITION_LEN_LIMIT = 10


def get_help_info():
    """Return the multi-line help text shown for ``--help``."""
    return '''Usage: sacct [OPTIONS]
  -b, --brief:                    Equivalent to '--format=jobid,state,error'.
  -S, --starttime:                Select jobs eligible after this time,
                                  default is 00:00:00 of the current day,
                                  is may be used with '-s' to specify job states,
                                  only support format 'YYYY/mm/dd HH24:MM:SS'
  -E, --endtime:                  Select jobs eligible before this time,
                                  is may be used with '-s' to specify job states,
                                  only support format 'YYYY/mm/dd HH24:MM:SS'
  --name:                         comma separated list of job names to view, default all
  -r, --partition:                comma separated list of job partitions to view, default all
  -s, --state:                    comma separated list of states to view,
                                  only support 'cd/completed', 'r/running' and 'f/failed'
  -u, --user:                     comma separated list of user names to view,
                                  this option only takes effect while user is administrator,
                                  normal user can not query other user's job
  -j, --jobs:                     comma separated list of jobs IDs to view, default is all
  
Help options:
  --help:                         show this help message
  --usage:                        display a brief summary'''


def get_usage_info():
    """Return the two-line synopsis shown for ``--usage``."""
    return '''Usage: sacct [--name job_names] [-S starttime] [-E endtime] [-r partition]
              [-s state] [-u user] [-j jobs] [-b]'''


def log(content):
    """Append ``content`` plus a newline to the debug log file.

    Nothing is written unless the module-level ``is_debug`` flag is true.
    The target path is the module-level ``log_file``.
    """
    if not is_debug:
        return
    # ``with`` closes the handle even when the write itself raises.
    with open(log_file, 'a') as log_handle:
        log_handle.write(content + '\n')


def init_log_file_name():
    """Set the global ``log_file`` to ``/tmp/sacct.<uid>.<pid>``."""
    global log_file
    name_parts = ['sacct', str(os.getuid()), str(os.getpid())]
    log_file = os.path.join('/tmp/', '.'.join(name_parts))


# read optional settings from environment variables
def init_config_with_env():
    """Enable debug logging when SLURM_TO_DONAU_DEBUG equals 'true' (any case).

    When enabled, the global ``is_debug`` is set and the log file name is
    initialised; otherwise nothing is changed.
    """
    global is_debug
    debug_setting = os.getenv('SLURM_TO_DONAU_DEBUG')
    if debug_setting is None:
        return
    if debug_setting.lower() != 'true':
        return
    is_debug = True
    init_log_file_name()


def get_invalid_option(err_message):
    """Pull the offending option out of a getopt error message.

    Returns the option text for an "option X not recognized" message.
    For an "option X requires argument" message the matching sacct-style
    error is printed and the process exits with status 1.
    Returns '' when the message matches neither pattern exactly once.
    """
    unrecognized = re.findall(r'option (.*) not recognized', err_message)
    if len(unrecognized) == 1:
        return unrecognized[0]
    missing_argument = re.findall(r'option (.*) requires argument', err_message)
    if len(missing_argument) == 1:
        print_single_option(missing_argument[0])
        sys.exit(1)
    return ''


def print_single_option(option):
    """Print the "requires an argument" error for a long or short option.

    Options shorter than two characters, or not starting with '-', print
    nothing.
    """
    if len(option) < 2:
        return
    if option.startswith('--'):
        message = "sacct: option '%s' requires an argument" % option
    elif option[0] == '-':
        message = "sacct: option requires an argument -- '%s'" % option[1:]
    else:
        return
    print(message)


def print_invalid_option(option):
    """Print the "invalid option" error followed by the --help hint.

    An option whose second character is '-' is reported in long form;
    anything else is reported in short form without its first character.
    An empty option prints nothing.
    """
    if not option:
        return
    if option[1:2] == '-':
        print("sacct: invalid option '%s'" % option)
    else:
        print("sacct: invalid option -- '%s'" % option[1:])
    print('Try "sacct --help" for more information')


def check_time_format(input_time):
    """Return True when ``input_time`` parses with SUPPORT_DATA_FORMAT."""
    try:
        datetime.strptime(input_time, SUPPORT_DATA_FORMAT)
    except ValueError:
        is_valid = False
    else:
        is_valid = True
    return is_valid


def is_administrator():
    """Return True when the ``dadmin show config`` probe exits with status 0."""
    exit_status = os.system('dadmin show config -n cluster.administrators > /dev/null')
    return exit_status == 0


def convert_donau_job_states(slurm_job_states):
    """Translate Slurm state names (case-insensitive) to Donau state names.

    Both cancelled and failed Slurm states map to Donau 'FAILED'. The result
    keeps the input order and is not deduplicated. An unknown state prints a
    fatal error and exits with status 1.
    """
    slurm_to_donau = {
        'CD': 'SUCCEEDED',
        'COMPLETED': 'SUCCEEDED',
        'R': 'RUNNING',
        'RUNNING': 'RUNNING',
        'F': 'FAILED',
        'FAILED': 'FAILED',
        'CA': 'FAILED',
        'CANCELLED': 'FAILED',
    }
    converted = []
    for slurm_state in slurm_job_states:
        donau_state = slurm_to_donau.get(slurm_state.upper())
        if donau_state is None:
            print("sacct: fatal: unrecognized job state value '%s'" % slurm_state)
            sys.exit(1)
        converted.append(donau_state)
    return converted


def parse_job_info(line):
    """Turn one line of djob output into a job dict.

    Returns an empty dict when the line has fewer than JOB_QUERY_IEN
    whitespace-separated fields, or when a --name filter is active and the
    job name is not in it.
    """
    fields = line.split()
    if len(fields) < JOB_QUERY_IEN:
        return {}
    name = fields[JOB_QUERY_NAME_INDEX]
    # apply the --name filter, if any
    if job_names and name not in job_names:
        log(name + ' is not in filter job name list: ' + ' '.join(job_names))
        return {}
    return {
        DISPLAY_NAME: name,
        DISPLAY_ID: fields[JOB_QUERY_ID_INDEX],
        DISPLAY_USER: fields[JOB_QUERY_USER_INDEX],
        DISPLAY_STATE: fields[JOB_QUERY_STATE_INDEX],
        DISPLAY_QUEUE: fields[JOB_QUERY_QUEUE_INDEX],
    }


def print_title():
    """Print the table header and its dashed rule, at most once per run.

    The brief layout has three columns and the full layout has seven; the
    global ``has_print_title`` records that the header was already printed.
    """
    global has_print_title
    if has_print_title:
        return
    has_print_title = True
    if is_brief:
        header = DISPLAY_SHORT_FORMAT.format(DISPLAY_JOB_ID,
                                             DISPLAY_JOB_STATE,
                                             DISPLAY_JOB_EXITCODE)
        rule = '------------ ---------- --------'
    else:
        header = DISPLAY_LONG_FORMAT.format(DISPLAY_JOB_ID,
                                            DISPLAY_JOB_NAME,
                                            DISPLAY_JOB_PARTITION,
                                            DISPLAY_JOB_ACCOUNT,
                                            DISPLAY_JOB_ALLOC_CPUS,
                                            DISPLAY_JOB_STATE,
                                            DISPLAY_JOB_EXITCODE)
        rule = '------------ ---------- ---------- ---------- ---------- ---------- --------'
    print(header)
    print(rule)


def check_result_error(out_line):
    """Classify one line of djob output.

    Returns True only when the first whitespace-separated token is all
    digits (a job row). Permission messages are recorded in the global
    ``job_err_info`` without their first word; separator lines, the JOB_ID
    header and plain "not exist" messages are dropped; every other line is
    recorded verbatim.
    """
    global job_err_info
    permission_markers = ('access to job', 'only permitted', 'the permission')
    if any(marker in out_line for marker in permission_markers):
        job_err_info.append(' '.join(out_line.split()[1:]))
        return False
    if out_line.startswith('---'):
        return False
    # "not exist" is silently dropped unless it is about a token, which must be shown
    mentions_token = 'token expired' in out_line or 'access token' in out_line
    if 'not exist' in out_line and not mentions_token:
        return False
    tokens = out_line.split()
    if tokens:
        if tokens[0].isdigit():
            return True
        if tokens[0] == 'JOB_ID':
            return False
    job_err_info.append(out_line)
    return False


def check_user(user_name):
    """Return True when ``user_name`` has an entry in the password database."""
    try:
        pwd.getpwnam(user_name)
    except KeyError:
        user_exists = False
    else:
        user_exists = True
    return user_exists


def execute_query(query_command):
    """Run a djob query through the shell and collect the jobs it lists.

    query_command: complete shell command line to execute.

    Returns a dict mapping job ID (string) to the job dict produced by
    parse_job_info(). The dict is empty when the command prints nothing or
    its output contains 'no matches'. When the output is a single line
    containing 'error', that line is printed and the process exits with
    status 1.

    Lines rejected by check_result_error() are not parsed; each rejected
    line also triggers print_title(), which prints the header once.
    """
    log('query_command: %s' % query_command)
    # Close the pipe explicitly instead of leaving it to garbage collection.
    pipe = os.popen(query_command)
    try:
        output = pipe.read()
    finally:
        pipe.close()
    log('get output: %s' % output)
    output_lines = output.splitlines()
    if not output_lines:
        return {}
    if len(output_lines) == 1 and 'error' in output_lines[0]:
        print(output_lines[0])
        sys.exit(1)
    if 'no matches' in output:
        return {}
    # First pass: separate job rows from headers, separators and error text.
    job_lines = []
    for out_line in output_lines:
        if check_result_error(out_line):
            job_lines.append(out_line)
        else:
            log('check djob result error: %s' % out_line)
            print_title()
    # Second pass: parse the job rows; an empty dict means "filtered out".
    job_dict = {}
    for job_line in job_lines:
        info = parse_job_info(job_line)
        if info:
            job_dict[info[DISPLAY_ID]] = info
    return job_dict


def get_job_dict():
    """Build the djob query from the global filters and run it.

    Returns the job dict produced by execute_query(). When a start or end
    time is given, the query is run once per Donau time filter type
    ('submit', 'start', 'end') and the results are merged.

    Exits with status 1 when a non-administrator names an unknown user,
    when a state cannot be converted, or when a job ID is not numeric.
    """
    # Local import keeps this block self-contained; pipes.quote is the
    # Python 2 spelling of shlex.quote.
    try:
        from shlex import quote
    except ImportError:
        from pipes import quote

    query_command = djob_path
    # -u/--user only takes effect while user is administrator
    if is_administrator():
        log('current user is administrator')
        if len(job_users) > 0:
            query_command += ' -u ' + quote(' '.join(job_users))
        else:
            query_command += ' -u all '
    else:
        log('current user is not administrator')
        for user_name in job_users:
            if not check_user(user_name):
                print('sacct: error: Invalid user id: %s' % user_name)
                sys.exit(1)
        if len(job_users) > 0:
            query_command += ' -u ' + quote(' '.join(job_users))
    # Option values are shell-quoted rather than wrapped in bare '...', so a
    # value containing a quote character cannot break or extend the command.
    donau_job_states = convert_donau_job_states(job_states)
    if len(donau_job_states) == 0:
        donau_job_states = ['SUCCEEDED', 'FAILED', 'RUNNING']
    query_command += ' -s ' + quote(' '.join(donau_job_states))
    # filter job partition
    if len(job_partitions) > 0:
        query_command += ' -q ' + quote(' '.join(job_partitions))
    # filter query time
    if len(job_start_time) > 0:
        query_command += ' -st ' + quote(job_start_time)
    if len(job_end_time) > 0:
        query_command += ' -et ' + quote(job_end_time)
    # Validate the job IDs once, before any query is executed.
    job_id_suffix = ''
    for job_id in job_ids:
        if not job_id.isdigit():
            print('sacct: fatal: Bad job/step specified: %s' % job_id)
            sys.exit(1)
        job_id_suffix += ' ' + job_id
    if len(job_start_time) == 0 and len(job_end_time) == 0:
        return execute_query(query_command + job_id_suffix)
    # if user specifies start time or end time, scripts should cover all the
    # filter types in donau: 'submit', 'start', 'end'
    final_job_dict = {}
    for filter_type in ('submit', 'start', 'end'):
        typed_command = query_command + ' -t ' + filter_type + job_id_suffix
        final_job_dict.update(execute_query(typed_command))
    return final_job_dict


def parse_err_code(exit_msg):
    """Derive an error code from a job exit message.

    When the message contains 'errCode:' and its last word is a number
    above 128, that number minus 128 is returned; in every other case the
    result is 0.
    """
    if 'errCode:' not in exit_msg:
        return 0
    words = exit_msg.split()
    if not words or not words[-1].isdigit():
        return 0
    raw_code = int(words[-1])
    if raw_code <= 128:
        return 0
    return raw_code - 128


def convert_specified_states():
    """Return the global ``job_states`` as canonical upper-case Slurm names.

    Short and long spellings are accepted case-insensitively; states that
    are not recognised are left out. Input order is kept.
    """
    canonical_names = {
        'CA': 'CANCELLED',
        'CANCELLED': 'CANCELLED',
        'R': 'RUNNING',
        'RUNNING': 'RUNNING',
        'F': 'FAILED',
        'FAILED': 'FAILED',
        'CD': 'COMPLETED',
        'COMPLETED': 'COMPLETED',
    }
    return [canonical_names[state.upper()]
            for state in job_states
            if state.upper() in canonical_names]


def print_job_info(job_info):
    """Print one job as a row of the sacct-style table.

    job_info: dict keyed by the DISPLAY_* constants.

    The Donau state is mapped to a Slurm state and exit code:
    SUCCEEDED -> COMPLETED 0:0, RUNNING -> RUNNING 0:0, FAILED with a
    non-zero error code -> CANCELLED 0:<code>, FAILED otherwise ->
    FAILED 1:0. Any other state is printed with empty state and exit code.
    Rows whose state is not selected by -s/--state are skipped. Brief mode
    prints the 3-column layout, otherwise the 7-column one with name and
    partition truncated to their length limits.
    """
    display_job_id = job_info.get(DISPLAY_ID)
    donau_job_state = job_info.get(DISPLAY_STATE)
    # A job without detail data has no error code recorded; count that as 0
    # instead of failing in the '%d' formatting below.
    donau_err_code = job_info.get(DISPLAY_ERR_CODE) or 0
    display_job_state = ''
    display_exit_code = ''
    if donau_job_state == 'SUCCEEDED':
        display_job_state = 'COMPLETED'
        display_exit_code = '0:0'
    elif donau_job_state == 'RUNNING':
        display_job_state = 'RUNNING'
        display_exit_code = '0:0'
    elif donau_job_state == 'FAILED':
        if donau_err_code != 0:
            display_exit_code = '0:%d' % donau_err_code
            display_job_state = 'CANCELLED'
        else:
            display_exit_code = '1:0'
            display_job_state = 'FAILED'
    # Only build the state filter when -s/--state was actually given.
    specified_job_states = convert_specified_states() if job_states else []
    if specified_job_states and display_job_state not in specified_job_states:
        return
    if is_brief:
        print(DISPLAY_SHORT_FORMAT.format(display_job_id,
                                          display_job_state,
                                          display_exit_code))
        return
    display_job_partition = (job_info.get(DISPLAY_QUEUE) or '')[:JOB_PARTITION_LEN_LIMIT]
    display_job_name = (job_info.get(DISPLAY_NAME) or '')[:JOB_NAME_LEN_LIMIT]
    display_job_account = ''
    # The CPU count may be an int, a numeric string, or missing altogether;
    # anything that cannot be converted is shown as 0.
    try:
        display_job_alloc_cpus = int(job_info.get(DISPLAY_ALLOC_CPUS))
    except (TypeError, ValueError):
        display_job_alloc_cpus = 0
    print(DISPLAY_LONG_FORMAT.format(display_job_id,
                                     display_job_name,
                                     display_job_partition,
                                     display_job_account,
                                     display_job_alloc_cpus,
                                     display_job_state,
                                     display_exit_code))


def print_query_job_info(info_dict):
    """Fetch per-job details with `djob -x -CO` and print every job.

    info_dict: dict mapping job ID (string) to its job dict; the dicts are
    updated in place with the allocated CPU total and the error code.

    A job ID may appear on several detail rows; their ALLOC_CPU values are
    summed. Exit code '10003' sets the error code to 9. For FAILED jobs the
    error code is otherwise taken from the exit message, but only while it
    is still 0. Jobs are printed in ascending numeric ID order.
    """
    query_custom_command = '%s -x -CO "JOB_ID ALLOC_CPU EXIT_CODE EXIT_MESSAGE"' % djob_path
    for job_id in info_dict:
        query_custom_command = '%s %s' % (query_custom_command, job_id)
    log('query_custom_command: %s' % query_custom_command)
    # Close the pipe explicitly instead of leaving it to garbage collection.
    pipe = os.popen(query_custom_command)
    try:
        output = pipe.read()
    finally:
        pipe.close()
    log('query detail result: %s' % output)
    for line in output.splitlines():
        query_info = line.split()
        # Count fields, not characters: a row is only usable when it
        # reaches the EXIT_CODE column, otherwise indexing below would fail.
        if len(query_info) <= JOB_CUSTOM_EXIT_CODE:
            continue
        job_id = query_info[JOB_CUSTOM_ID_INDEX]
        if not job_id.isdigit():
            continue
        info = info_dict.get(job_id)
        if info is None:
            # detail row for a job that was not part of the query result
            continue
        # calculate allocated cpus
        if info.get(DISPLAY_ALLOC_CPUS) is None:
            info[DISPLAY_ALLOC_CPUS] = 0
        alloc_cpus = query_info[JOB_CUSTOM_ALLOC_CPUS_INDEX]
        if alloc_cpus.isdigit():
            info[DISPLAY_ALLOC_CPUS] += int(alloc_cpus)
        # parse exit message
        if info.get(DISPLAY_ERR_CODE) is None:
            info[DISPLAY_ERR_CODE] = 0
        # if job was terminated by user, set error code as 9
        if query_info[JOB_CUSTOM_EXIT_CODE] == '10003':
            info[DISPLAY_ERR_CODE] = 9
            continue
        if info.get(DISPLAY_STATE) == 'FAILED' and info[DISPLAY_ERR_CODE] == 0:
            exit_message = ' '.join(query_info[JOB_CUSTOM_EXIT_MESSAGE:])
            info[DISPLAY_ERR_CODE] = parse_err_code(exit_message)

    # Sort the original keys numerically so that no ID is lost by an
    # int -> str round-trip (e.g. an ID with leading zeros).
    for job_id in sorted(info_dict, key=int):
        print_job_info(info_dict[job_id])


def check_donau_cli_env():
    """Locate the djob executable from the Donau CLI environment.

    Requires both CCSCHEDULER_CLI_HOME and CCS_CLI_HOME to be set and
    stores <CCS_CLI_HOME>/bin/djob in the global ``djob_path``. Prints a
    hint and exits with status 1 when a variable is missing or the path
    does not exist.
    """
    global djob_path
    source_hint = 'source configure \'profile.env\'(bash) or \'cshrc.env\'(csh) first'
    sched_cli_home = os.getenv('CCSCHEDULER_CLI_HOME')
    ccs_cli_home = os.getenv('CCS_CLI_HOME')
    if sched_cli_home is None or ccs_cli_home is None:
        print('cli env home is not configured, ' + source_hint)
        sys.exit(1)
    djob_path = os.path.join(ccs_cli_home, 'bin', 'djob')
    if os.path.exists(djob_path):
        return
    print('djob path is not exist,' + source_hint)
    sys.exit(1)


if __name__ == '__main__':
    # Script entry point: parse the sacct-style options into the module
    # globals, query djob, then print the job table followed by any
    # collected error lines.
    if os.getuid() == 0:
        print('Permission denied. The root user is not allowed to operate.')
        sys.exit(1)
    check_donau_cli_env()
    init_config_with_env()
    # getopt short-option spec: a trailing ':' marks an option that takes a value.
    short_options = 'bS:E:j:s:u:r:'
    # Long options that take a value must end with '='. Without it,
    # --starttime/--endtime never receive their time argument.
    long_options = ['brief', 'jobs=', 'help', 'name=', 'state=', 'user=',
                    'partition=', 'usage', 'starttime=', 'endtime=']
    try:
        opts, args = getopt.getopt(sys.argv[1:], short_options, long_options)
        for opt_name, opt_value in opts:
            if opt_name == '--help':
                print(get_help_info())
                sys.exit()
            elif opt_name == '--usage':
                print(get_usage_info())
                sys.exit()
            elif opt_name in ('-b', '--brief'):
                is_brief = True
            elif opt_name == '--name':
                log("get job names from command options: %s" % opt_value)
                job_names = opt_value.split(",")
            elif opt_name in ('-s', '--state'):
                log("get job states from command options: %s" % opt_value)
                job_states = opt_value.split(",")
            elif opt_name in ('-u', '--user'):
                log("get job users from command options: %s" % opt_value)
                job_users = opt_value.split(",")
            elif opt_name in ('-r', '--partition'):
                log("get job partitions from command options: %s" % opt_value)
                job_partitions = opt_value.split(",")
            elif opt_name in ('-j', '--jobs'):
                log("get job IDs from command options: %s" % opt_value)
                job_ids = opt_value.split(",")
            elif opt_name in ('-S', '--starttime'):
                log("get job start time from command options: %s" % opt_value)
                if not check_time_format(opt_value):
                    print('Invalid time specification: %s' % opt_value)
                    sys.exit(1)
                job_start_time = opt_value
            elif opt_name in ('-E', '--endtime'):
                log("get job end time from command options: %s" % opt_value)
                if not check_time_format(opt_value):
                    print('Invalid time specification: %s' % opt_value)
                    sys.exit(1)
                job_end_time = opt_value
        # positional arguments are not supported
        if len(args) > 0:
            log('remain arguments: %s' % ' '.join(args))
            print('sacct: error: Unknown arguments:')
            for arg in args:
                print('sacct: error:  %s' % arg)
            sys.exit(1)
    except getopt.GetoptError as err:
        invalid_option = get_invalid_option(str(err))
        if len(invalid_option) != 0:
            print_invalid_option(invalid_option)
        else:
            print(str(err))
        sys.exit(1)

    job_info_dict = get_job_dict()
    print_title()
    if len(job_info_dict) > 0:
        print_query_job_info(job_info_dict)
    # Print each distinct error once, in the order it was first seen, so the
    # output is the same from run to run.
    seen_err_info = set()
    for err_info in job_err_info:
        if err_info in seen_err_info:
            continue
        seen_err_info.add(err_info)
        print(err_info)

